# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import json
import urllib.parse

from knack.util import CLIError
from knack.log import get_logger

import azure.core.rest

from azure.mgmt.cognitiveservices.models import Account as CognitiveServicesAccount, Sku, \
    VirtualNetworkRule, IpRule, NetworkRuleSet, NetworkRuleAction, \
    AccountProperties as CognitiveServicesAccountProperties, ApiProperties as CognitiveServicesAccountApiProperties, \
    Identity, ResourceIdentityType as IdentityType, \
    Deployment, DeploymentModel, DeploymentScaleSettings, DeploymentProperties, \
    CommitmentPlan, CommitmentPlanProperties, CommitmentPeriod, \
    ConnectionPropertiesV2BasicResource, ConnectionUpdateContent, \
    Project, ProjectProperties
from azure.cli.command_modules.cognitiveservices._client_factory import cf_accounts, cf_resource_skus
from azure.cli.core.azclierror import BadRequestError
from azure.cli.command_modules.cognitiveservices._utils import load_connection_from_source, compose_identity

# Module-level logger named after this module; used for the deprecation warning in list_skus.
logger = get_logger(__name__)


def list_resources(client, resource_group_name=None):
    """
    List Azure Cognitive Services accounts.

    :param client: accounts client exposing ``list`` and ``list_by_resource_group``.
    :param resource_group_name: when given (and non-empty), only accounts in this resource group are listed.
    :return: whatever the selected client listing call returns.
    """
    if not resource_group_name:
        # No resource group filter: fall back to the unscoped listing.
        return client.list()
    return client.list_by_resource_group(resource_group_name)


def recover(client, location, resource_group_name, account_name):
    """
    Recover a deleted Azure Cognitive Services account.

    Sends a create request whose properties carry ``restore = True``.

    :param client: accounts client exposing ``begin_create``.
    :param location: location set on the account payload.
    :param resource_group_name: resource group of the account to recover.
    :param account_name: name of the account to recover.
    :return: whatever ``client.begin_create`` returns (presumably a long-running-operation poller).
    """
    restore_properties = CognitiveServicesAccountProperties()
    restore_properties.restore = True

    account = CognitiveServicesAccount(properties=restore_properties)
    account.location = location
    return client.begin_create(resource_group_name, account_name, account)


def list_usages(client, resource_group_name, account_name):
    """
    List usages for an Azure Cognitive Services account.

    :param client: accounts client exposing ``list_usages``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :return: the ``value`` attribute of the client's ``list_usages`` result.
    """
    usage_result = client.list_usages(resource_group_name, account_name)
    return usage_result.value


def list_kinds(client):
    """
    List all valid kinds for Azure Cognitive Services accounts.

    :param client: the ResourceSkusOperations; ``list()`` is expected to yield SKUs that each carry a ``kind``.
    :return: a sorted list of the distinct kinds found across the listed SKUs.
    """
    # Every SKU reports its kind, so the set of kinds is derived from the SKU listing.
    distinct_kinds = set()
    for sku in client.list():
        distinct_kinds.add(sku.kind)
    return sorted(distinct_kinds)


def list_skus(
    cmd, kind=None, location=None, resource_group_name=None, account_name=None
):
    """
    List SKUs, either for an existing account (deprecated path) or filtered by kind and location.

    :param cmd: command object; its ``cli_ctx`` is handed to the client factories.
    :param kind: when not None, keep only SKUs whose ``kind`` equals this value.
    :param location: when not None, keep only SKUs whose ``locations`` contain it (case-insensitive).
    :param resource_group_name: resource group of an existing account (deprecated usage).
    :param account_name: name of an existing account (deprecated usage).
    :return: the account client's ``list_skus`` result on the deprecated path, otherwise a list of matching SKUs.
    :raises CLIError: if an account name is supplied without a resource group.
    """
    account_scoped = not (resource_group_name is None and account_name is None)
    if account_scoped:
        logger.warning(
            "list-skus with an existing account has been deprecated and will be removed in a future release."
        )
        if resource_group_name is None:
            # Only account_name was supplied here; the account lookup needs the resource group too.
            raise CLIError("--resource-group is required when --name is specified.")
        # Legacy behavior is kept so existing callers do not break.
        accounts_client = cf_accounts(cmd.cli_ctx)
        return accounts_client.list_skus(resource_group_name, account_name)

    def _matches(candidate):
        """Return True when the SKU passes both the kind and the location filter."""
        if kind is not None and candidate.kind != kind:
            return False
        if location is None:
            return True
        wanted = location.lower()
        offered = [entry.lower() for entry in candidate.locations]
        return wanted in offered

    all_skus = cf_resource_skus(cmd.cli_ctx).list()
    return [candidate for candidate in all_skus if _matches(candidate)]


def _is_valid_kind_change(current_kind, target_kind):
    valid_upgrades = {"AIServices": ["OpenAI"], "OpenAI": ["AIServices"]}
    return target_kind in valid_upgrades.get(current_kind, [])


def _kind_uses_project_management(kind):
    return kind in ["AIServices"]


def create(
    client,
    resource_group_name,
    account_name,
    sku_name,
    kind,
    location,
    custom_domain=None,
    tags=None,
    api_properties=None,
    assign_identity=False,
    storage=None,
    encryption=None,
    allow_project_management=None,
    yes=None,
):  # pylint: disable=unused-argument
    """
    Create an Azure Cognitive Services account.

    :param client: accounts client exposing ``begin_create``.
    :param resource_group_name: resource group to create the account in.
    :param account_name: name of the new account.
    :param sku_name: name of the SKU for the account.
    :param kind: account kind.
    :param location: account location.
    :param custom_domain: custom sub-domain name; ignored when empty.
    :param tags: tags placed on the account.
    :param api_properties: API properties, deserialized into the SDK API-properties model.
    :param assign_identity: when truthy, a system-assigned identity is attached.
    :param storage: JSON string describing user-owned storage.
    :param encryption: JSON string describing the encryption settings.
    :param allow_project_management: project-management flag; defaults to True for kinds that use
        project management when left as None. A truthy value also attaches a system-assigned identity.
    :param yes: not read by this function.
    :return: whatever ``client.begin_create`` returns (presumably a long-running-operation poller).
    """
    if allow_project_management is None and _kind_uses_project_management(kind):
        allow_project_management = True

    account_properties = CognitiveServicesAccountProperties()
    if api_properties is not None:
        account_properties.api_properties = CognitiveServicesAccountApiProperties.deserialize(api_properties)
    if custom_domain:
        account_properties.custom_sub_domain_name = custom_domain
    if storage is not None:
        account_properties.user_owned_storage = json.loads(storage)
    if encryption is not None:
        account_properties.encryption = json.loads(encryption)
    account_properties.allow_project_management = allow_project_management

    account = CognitiveServicesAccount(
        sku=Sku(name=sku_name),
        kind=kind,
        location=location,
        properties=account_properties,
        tags=tags,
    )

    needs_system_identity = assign_identity or allow_project_management
    if needs_system_identity:
        account.identity = Identity(type=IdentityType.SYSTEM_ASSIGNED)

    return client.begin_create(resource_group_name, account_name, account)


def update(
    client,
    resource_group_name,
    account_name,
    sku_name=None,
    custom_domain=None,
    tags=None,
    api_properties=None,
    storage=None,
    encryption=None,
    allow_project_management=None,
    kind=None,
):
    """
    Update an Azure Cognitive Services account.

    The existing account is fetched (at most once) when the current SKU name or the current kind is needed.

    :param client: accounts client exposing ``get`` and ``begin_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :param sku_name: new SKU name; when None the account's current SKU name is reused.
    :param custom_domain: custom sub-domain name; ignored when empty.
    :param tags: tags placed on the update payload.
    :param api_properties: API properties, deserialized into the SDK API-properties model.
    :param storage: JSON string describing user-owned storage.
    :param encryption: JSON string describing the encryption settings.
    :param allow_project_management: project-management flag; when None and ``kind`` is a kind that uses
        project management, it is sent as True.
    :param kind: target account kind; must equal the current kind or be an allowed kind change.
    :return: whatever ``client.begin_update`` returns (presumably a long-running-operation poller).
    :raises BadRequestError: if ``kind`` differs from the current kind and the change is not allowed.
    """
    existing = None
    if sku_name is None:
        # Reuse the SKU the account already has.
        existing = client.get(resource_group_name, account_name)
        sku_name = existing.sku.name

    account_properties = CognitiveServicesAccountProperties()
    if api_properties is not None:
        account_properties.api_properties = CognitiveServicesAccountApiProperties.deserialize(api_properties)
    if custom_domain:
        account_properties.custom_sub_domain_name = custom_domain
    if allow_project_management is not None:
        account_properties.allow_project_management = allow_project_management
    if storage is not None:
        account_properties.user_owned_storage = json.loads(storage)
    if encryption is not None:
        account_properties.encryption = json.loads(encryption)

    if kind is not None:
        if existing is None:
            existing = client.get(resource_group_name, account_name)
        current_kind = existing.kind
        if kind != current_kind and not _is_valid_kind_change(current_kind, kind):
            raise BadRequestError(
                "Changing the account kind from '{}' to '{}' is not supported.".format(current_kind, kind)
            )
        if allow_project_management is None and _kind_uses_project_management(kind):
            account_properties.allow_project_management = True

    account = CognitiveServicesAccount(
        kind=kind,
        sku=Sku(name=sku_name),
        properties=account_properties,
        tags=tags,
    )
    return client.begin_update(resource_group_name, account_name, account)


def default_network_acls():
    """
    Build the rule set used when an account has no network ACLs yet.

    :return: a NetworkRuleSet whose default action is deny and whose IP and virtual network rule lists are empty.
    """
    acls = NetworkRuleSet()
    acls.virtual_network_rules = []
    acls.ip_rules = []
    acls.default_action = NetworkRuleAction.deny
    return acls


def list_network_rules(client, resource_group_name, account_name):
    """
    List network rules for an Azure Cognitive Services account.

    :param client: accounts client exposing ``get``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :return: the account's ``properties.network_acls``, or the default rule set when the account has none.
    """
    account = client.get(resource_group_name, account_name)
    acls = account.properties.network_acls
    return default_network_acls() if acls is None else acls


def add_network_rule(
    client,
    resource_group_name,
    account_name,
    subnet=None,
    vnet_name=None,
    ip_address=None,
):  # pylint: disable=unused-argument
    """
    Add a network rule for an Azure Cognitive Services account.

    The account's current rule set (or the default one when it has none) is extended and sent back
    in an update request.

    :param client: accounts client exposing ``get`` and ``begin_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :param subnet: fully qualified resource ID of the subnet to allow.
    :param vnet_name: not read by this function.
    :param ip_address: value of the IP rule to add.
    :return: whatever ``client.begin_update`` returns (presumably a long-running-operation poller).
    :raises CLIError: if ``subnet`` is given but is not a valid resource ID.
    """
    account = client.get(resource_group_name, account_name)
    acls = account.properties.network_acls
    if acls is None:
        acls = default_network_acls()

    if subnet:
        from azure.mgmt.core.tools import is_valid_resource_id

        if not is_valid_resource_id(subnet):
            raise CLIError(
                "Expected fully qualified resource ID: got '{}'".format(subnet)
            )

        # Start from an empty list when the rule set has no (or no non-empty) vnet rule list.
        vnet_rules = acls.virtual_network_rules or []
        vnet_rules.append(
            VirtualNetworkRule(id=subnet, ignore_missing_vnet_service_endpoint=True)
        )
        acls.virtual_network_rules = vnet_rules

    if ip_address:
        ip_rules = acls.ip_rules or []
        ip_rules.append(IpRule(value=ip_address))
        acls.ip_rules = ip_rules

    update_properties = CognitiveServicesAccountProperties()
    update_properties.network_acls = acls
    payload = CognitiveServicesAccount(properties=update_properties)

    return client.begin_update(resource_group_name, account_name, payload)


def remove_network_rule(
    client,
    resource_group_name,
    account_name,
    ip_address=None,
    subnet=None,
    vnet_name=None,
):  # pylint: disable=unused-argument
    """
    Remove a network rule for an Azure Cognitive Services account.

    :param client: accounts client exposing ``get`` and ``begin_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :param ip_address: IP rules whose value equals this are removed.
    :param subnet: virtual network rules whose ID ends with this value are removed.
    :param vnet_name: not read by this function.
    :return: the fetched account when it has no network ACLs at all; otherwise whatever
        ``client.begin_update`` returns (presumably a long-running-operation poller).
    """
    sa = client.get(resource_group_name, account_name)
    rules = sa.properties.network_acls
    if rules is None:
        # Nothing to remove. Return the account as fetched instead of issuing an update call:
        # the client is only ever used through begin_update(rg, name, params) in this module,
        # so the previous client.update(rg, name) call had neither a matching method nor a payload.
        return sa

    if subnet:
        # The rule list itself may be None on an existing rule set; treat that as empty.
        rules.virtual_network_rules = [
            x for x in (rules.virtual_network_rules or []) if not x.id.endswith(subnet)
        ]
    if ip_address:
        rules.ip_rules = [x for x in (rules.ip_rules or []) if x.value != ip_address]

    properties = CognitiveServicesAccountProperties()
    properties.network_acls = rules
    params = CognitiveServicesAccount(properties=properties)

    return client.begin_update(resource_group_name, account_name, params)


def identity_assign(client, resource_group_name, account_name):
    """
    Assign a system-assigned identity to an Azure Cognitive Services account.

    :param client: accounts client exposing ``begin_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :return: the ``identity`` of the update result, or an empty dict when that identity is falsy.
    """
    payload = CognitiveServicesAccount()
    payload.identity = Identity(type=IdentityType.SYSTEM_ASSIGNED)

    # .result() is called on the returned operation so the updated account can be inspected.
    operation = client.begin_update(resource_group_name, account_name, payload)
    updated_account = operation.result()
    return updated_account.identity or {}


def identity_remove(client, resource_group_name, account_name):
    """
    Remove the identity from an Azure Cognitive Services account.

    Sends an update whose identity type is ``NONE``.

    :param client: accounts client exposing ``begin_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :return: whatever ``client.begin_update`` returns (presumably a long-running-operation poller).
    """
    payload = CognitiveServicesAccount()
    payload.identity = Identity(type=IdentityType.NONE)
    return client.begin_update(resource_group_name, account_name, payload)


def identity_show(client, resource_group_name, account_name):
    """
    Show the identity of an Azure Cognitive Services account.

    :param client: accounts client exposing ``get``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :return: the account's ``identity``, or an empty dict when that identity is falsy.
    """
    account = client.get(resource_group_name, account_name)
    return account.identity or {}


def deployment_begin_create_or_update(
        client, resource_group_name, account_name, deployment_name,
        model_format, model_name, model_version, model_source=None,
        sku_name=None, sku_capacity=None,
        scale_settings_scale_type=None, scale_settings_capacity=None,
        spillover_deployment_name=None):
    """
    Create or update a deployment for an Azure Cognitive Services account.

    :param client: deployments client exposing ``begin_create_or_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :param deployment_name: name of the deployment.
    :param model_format: format of the model to deploy.
    :param model_name: name of the model to deploy.
    :param model_version: version of the model to deploy.
    :param model_source: model source; only set when not None.
    :param sku_name: SKU name; a SKU is attached only when this is not None.
    :param sku_capacity: SKU capacity; only used together with ``sku_name``.
    :param scale_settings_scale_type: scale type; scale settings are attached only when this is not None.
    :param scale_settings_capacity: scale capacity; only used together with ``scale_settings_scale_type``.
    :param spillover_deployment_name: spillover deployment name; only set when not None.
    :return: whatever ``client.begin_create_or_update`` returns.
    """
    model = DeploymentModel()
    model.format = model_format
    model.name = model_name
    model.version = model_version
    if model_source is not None:
        model.source = model_source

    deployment_properties = DeploymentProperties()
    deployment_properties.model = model
    if scale_settings_scale_type is not None:
        scale_settings = DeploymentScaleSettings()
        scale_settings.scale_type = scale_settings_scale_type
        scale_settings.capacity = scale_settings_capacity
        deployment_properties.scale_settings = scale_settings
    if spillover_deployment_name is not None:
        deployment_properties.spillover_deployment_name = spillover_deployment_name

    deployment = Deployment()
    deployment.properties = deployment_properties
    if sku_name is not None:
        deployment_sku = Sku(name=sku_name)
        deployment_sku.capacity = sku_capacity
        deployment.sku = deployment_sku

    # polling=False is passed through as-is; presumably it makes the call return without polling.
    return client.begin_create_or_update(
        resource_group_name, account_name, deployment_name, deployment, polling=False
    )


def commitment_plan_create_or_update(
    client,
    resource_group_name,
    account_name,
    commitment_plan_name,
    hosting_model,
    plan_type,
    auto_renew,
    current_tier=None,
    current_count=None,
    next_tier=None,
    next_count=None,
):
    """
    Create or update a commitment plan for an Azure Cognitive Services account.

    :param client: commitment plans client exposing ``create_or_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :param commitment_plan_name: name of the commitment plan.
    :param hosting_model: hosting model stored on the plan.
    :param plan_type: plan type stored on the plan.
    :param auto_renew: auto-renew flag stored on the plan.
    :param current_tier: tier of the current commitment period.
    :param current_count: count of the current commitment period.
    :param next_tier: tier of the next commitment period.
    :param next_count: count of the next commitment period.
    :return: whatever ``client.create_or_update`` returns.
    """
    def _period(tier, count):
        """Build a CommitmentPeriod, or return None when neither tier nor count was supplied."""
        if tier is None and count is None:
            return None
        period = CommitmentPeriod()
        period.tier = tier
        period.count = count
        return period

    plan_properties = CommitmentPlanProperties()
    plan_properties.hosting_model = hosting_model
    plan_properties.plan_type = plan_type

    # A period is only attached when at least one of its two values was given.
    current_period = _period(current_tier, current_count)
    if current_period is not None:
        plan_properties.current = current_period
    next_period = _period(next_tier, next_count)
    if next_period is not None:
        plan_properties.next = next_period
    plan_properties.auto_renew = auto_renew

    commitment_plan = CommitmentPlan()
    commitment_plan.properties = plan_properties
    return client.create_or_update(
        resource_group_name, account_name, commitment_plan_name, commitment_plan
    )


# Query-string parameters attached to the agent requests built below.
# The paginated list functions work on a copy before adding their "after" cursor,
# so this shared dict is never mutated.
AGENT_API_VERSION_PARAMS = {"api-version": "2025-11-15-preview"}


def _create_agent_request(
    method: str,
    agent_name: str,
    agent_version: str = None,
    *,
    container: bool = False,
    action: str = None,
    body: dict = None,
):
    """
    Build an ``azure.core.rest.HttpRequest`` addressing an agent, an agent version or a version's container.

    The URL has the shape ``/agents/{name}[/versions/{version}[/containers/default]][:{action}]``.

    :param method: HTTP method of the request.
    :param agent_name: agent name; percent-encoded into a single path segment.
    :param agent_version: agent version; when falsy the request targets the agent itself.
    :param container: when True, target the ``containers/default`` resource of the version.
    :param action: optional action name appended to the URL as ``:{action}``.
    :param body: optional JSON body of the request.
    :return: the constructed request, carrying the module's api-version query parameters.
    :raises ValueError: if ``container`` is True but no ``agent_version`` is given.
    """
    if container and not agent_version:
        raise ValueError("container=True requires agent_version to be specified")

    # safe="" makes quote() percent-encode "/" as well. With the default (safe="/") a name or
    # version containing a slash would spill into additional path segments and address a
    # different resource.
    segments = ["agents", urllib.parse.quote(agent_name, safe="")]
    if agent_version:
        segments += ["versions", urllib.parse.quote(agent_version, safe="")]
        if container:
            segments += ["containers", "default"]
    url = "/" + "/".join(segments)

    if action:
        url += f":{action}"
    return azure.core.rest.HttpRequest(
        method, url, json=body, params=AGENT_API_VERSION_PARAMS
    )


def _invoke_agent_container_operation(
    client,
    agent_name,
    agent_version,
    *,
    action: str,
    min_replicas=None,
    max_replicas=None,
):
    """
    POST an action to the default container of an agent version and return the decoded JSON reply.

    :param client: client exposing ``send_request``.
    :param agent_name: name of the agent.
    :param agent_version: version of the agent whose container is addressed.
    :param action: action name appended to the container URL.
    :param min_replicas: sent as ``min_replicas`` in the body when not None.
    :param max_replicas: sent as ``max_replicas`` in the body when not None.
    :return: the JSON-decoded response body.
    """
    replica_settings = {"min_replicas": min_replicas, "max_replicas": max_replicas}
    # Only the settings that were actually supplied go into the body; it may end up empty.
    payload = {key: value for key, value in replica_settings.items() if value is not None}

    container_request = _create_agent_request(
        "POST",
        agent_name,
        agent_version,
        container=True,
        action=action,
        body=payload,
    )
    reply = client.send_request(container_request)
    reply.raise_for_status()
    return reply.json()


def agent_update(
    client,
    account_name,
    project_name,
    agent_name,
    agent_version,
    min_replicas=None,
    max_replicas=None,
    description=None,
    tags=None,
):  # pylint: disable=unused-argument
    """
    Update hosted agent deployment configuration.

    Sends the ``update`` action for the agent version's container with the horizontal scale
    configuration (min and max replica). A new version is not created for this update.

    NOTE(review): ``description`` and ``tags`` are accepted but not forwarded by this function;
    only ``min_replicas`` and ``max_replicas`` reach the request body.

    :return: the JSON-decoded response of the update action.
    """
    replica_bounds = {"min_replicas": min_replicas, "max_replicas": max_replicas}
    return _invoke_agent_container_operation(
        client, agent_name, agent_version, action="update", **replica_bounds
    )


def agent_stop(
    client, account_name, project_name, agent_name, agent_version
):  # pylint: disable=unused-argument
    """
    Stop hosted agent deployment.

    Sends the ``stop`` action for the agent version's container; ``account_name`` and
    ``project_name`` are not read here.

    :return: the JSON-decoded response of the stop action.
    """
    return _invoke_agent_container_operation(
        client=client,
        agent_name=agent_name,
        agent_version=agent_version,
        action="stop",
    )


def agent_start(
    client, account_name, project_name, agent_name, agent_version
):  # pylint: disable=unused-argument
    """
    Start hosted agent deployment.

    Sends the ``start`` action for the agent version's container; ``account_name`` and
    ``project_name`` are not read here.

    :return: the JSON-decoded response of the start action.
    """
    return _invoke_agent_container_operation(
        client=client,
        agent_name=agent_name,
        agent_version=agent_version,
        action="start",
    )


def agent_delete_deployment(
    client, account_name, project_name, agent_name, agent_version
):  # pylint: disable=unused-argument
    """
    Delete hosted agent deployment.

    Deletes the agent deployment only, agent version associated with the deployment remains.
    The request is a POST of the ``delete`` action to the version's container.

    :return: the JSON-decoded response body.
    """
    delete_request = _create_agent_request(
        "POST",
        agent_name,
        agent_version,
        container=True,
        action="delete",
    )
    reply = client.send_request(delete_request)
    reply.raise_for_status()
    return reply.json()


def agent_delete(
    client, account_name, project_name, agent_name, agent_version=None
):  # pylint: disable=unused-argument
    """
    Delete hosted agent version or all versions.

    If agent_version is provided, the DELETE request targets that version of the agent.
    If agent_version is not provided, the DELETE request targets the agent itself, which covers
    all versions associated with the agent name.

    :return: the JSON-decoded response body.
    """
    delete_request = _create_agent_request("DELETE", agent_name, agent_version)
    reply = client.send_request(delete_request)
    reply.raise_for_status()
    return reply.json()


def agent_list(client, account_name, project_name):  # pylint: disable=unused-argument
    """
    List agents.

    Pages through ``GET /agents``: each response's ``data`` entries are collected, and while the
    response reports ``has_more`` the next request carries ``after=<last_id>``.

    :param client: client exposing ``send_request``.
    :return: a list with the ``data`` entries of every page.
    """
    collected = []
    query = dict(AGENT_API_VERSION_PARAMS)  # private copy; the cursor is added to it below
    more_pages = True
    while more_pages:
        page_request = azure.core.rest.HttpRequest("GET", "/agents", params=query)
        reply = client.send_request(page_request)
        reply.raise_for_status()
        page = reply.json()
        collected.extend(page.get("data", []))
        more_pages = page.get("has_more")
        if more_pages:
            query["after"] = page.get("last_id")
    return collected


def agent_versions_list(
    client, account_name, project_name, agent_name
):  # pylint: disable=unused-argument
    """
    List all versions of a hosted agent.

    Pages through ``GET /agents/{name}/versions``: each response's ``data`` entries are collected,
    and while the response reports ``has_more`` the next request carries ``after=<last_id>``.

    :param client: client exposing ``send_request``.
    :param agent_name: agent name; percent-encoded into a single path segment.
    :return: a list with the ``data`` entries of every page.
    """
    # safe="" makes quote() percent-encode "/" too; with the default a name containing a slash
    # would spill into additional path segments and address a different resource.
    versions_url = f"/agents/{urllib.parse.quote(agent_name, safe='')}/versions"

    versions = []
    params = AGENT_API_VERSION_PARAMS.copy()  # private copy; the cursor is added to it below
    while True:
        request = azure.core.rest.HttpRequest("GET", versions_url, params=params)
        response = client.send_request(request)
        response.raise_for_status()
        body = response.json()
        versions.extend(body.get("data", []))
        if not body.get("has_more"):
            return versions
        params["after"] = body.get("last_id")


def agent_show(
    client, account_name, project_name, agent_name
):  # pylint: disable=unused-argument
    """
    Show details of a hosted agent.

    :param client: client exposing ``send_request``.
    :param agent_name: agent name; percent-encoded into a single path segment.
    :return: the JSON-decoded body of ``GET /agents/{name}``.
    """
    # safe="" makes quote() percent-encode "/" too; with the default a name containing a slash
    # would spill into additional path segments and address a different resource.
    request = azure.core.rest.HttpRequest(
        "GET",
        f"/agents/{urllib.parse.quote(agent_name, safe='')}",
        params=AGENT_API_VERSION_PARAMS,
    )
    response = client.send_request(request)
    response.raise_for_status()
    return response.json()


def project_create(
        client,
        resource_group_name,
        account_name,
        project_name,
        location,
        assign_identity=False,
        user_assigned_identity=None,
        description=None,
        display_name=None,
        no_wait=False,
):
    """
    Create a project for Azure Cognitive Services account.

    :param client: projects client exposing ``begin_create``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account that owns the project.
    :param project_name: name of the project.
    :param location: location set on the project.
    :param assign_identity: request a system-assigned identity; forced to True when no
        user-assigned identity is supplied.
    :param user_assigned_identity: user-assigned identity handed to ``compose_identity``.
    :param description: project description.
    :param display_name: project display name.
    :param no_wait: when True the operation is started without polling for its completion.
    :return: whatever ``client.begin_create`` returns (presumably a long-running-operation poller).
    """
    project = Project(properties=ProjectProperties(display_name=display_name, description=description))
    project.location = location
    if user_assigned_identity is None:
        # Without a user-assigned identity the project always gets a system-assigned one.
        assign_identity = True
    project.identity = compose_identity(system_assigned=assign_identity, user_assigned_identity=user_assigned_identity)
    # The SDK's `polling` keyword means "poll until done", i.e. the opposite of no_wait.
    # Passing no_wait straight through disabled polling by default and enabled it for --no-wait.
    return client.begin_create(resource_group_name, account_name, project_name, project, polling=not no_wait)


def project_update(
    client,
    resource_group_name,
    account_name,
    project_name,
    description=None,
    display_name=None,
):
    """
    Update a project for Azure Cognitive Services account.

    Only the values that are not None are set on the update payload.

    :param client: projects client exposing ``begin_update``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account that owns the project.
    :param project_name: name of the project.
    :param description: new project description.
    :param display_name: new project display name.
    :return: whatever ``client.begin_update`` returns (presumably a long-running-operation poller).
    """
    requested_changes = {"description": description, "display_name": display_name}

    updated_properties = ProjectProperties()
    for field_name, new_value in requested_changes.items():
        if new_value is not None:
            setattr(updated_properties, field_name, new_value)

    payload = Project(properties=updated_properties)
    return client.begin_update(resource_group_name, account_name, project_name, payload)


def account_connection_create(
    client,
    resource_group_name,
    account_name,
    connection_name,
    file,
):
    """
    Create a connection for Azure Cognitive Services account.

    :param client: account connections client exposing ``create``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account.
    :param connection_name: name of the connection.
    :param file: source handed to ``load_connection_from_source`` to obtain the connection properties.
    :return: whatever ``client.create`` returns.
    """
    connection_resource = ConnectionPropertiesV2BasicResource(
        properties=load_connection_from_source(source=file)
    )
    return client.create(resource_group_name, account_name, connection_name, connection_resource)


# This function is intended to be used with the 'generic_update_command' per
# https://github.com/Azure/azure-cli/blob/0b06b4f295766bcadaebdb7cf8fc05c7d6c9a5a8/doc/authoring_command_modules/authoring_commands.md#generic-update-commands
def account_connection_update(
    instance,
):
    """
    Update a connection for Azure Cognitive Services account.

    :param instance: the existing connection; its ``properties`` are carried over.
    :return: a ConnectionUpdateContent wrapping ``instance.properties``.
    """
    return ConnectionUpdateContent(properties=instance.properties)


def project_connection_create(
    client,
    resource_group_name,
    account_name,
    project_name,
    connection_name,
    file,
):
    """
    Create a connection for a project of an Azure Cognitive Services account.

    :param client: project connections client exposing ``create``.
    :param resource_group_name: resource group of the account.
    :param account_name: name of the account that owns the project.
    :param project_name: name of the project.
    :param connection_name: name of the connection.
    :param file: source handed to ``load_connection_from_source`` to obtain the connection properties.
    :return: whatever ``client.create`` returns.
    """
    connection_resource = ConnectionPropertiesV2BasicResource(
        properties=load_connection_from_source(source=file)
    )
    return client.create(
        resource_group_name, account_name, project_name, connection_name, connection_resource
    )


# This function is intended to be used with the 'generic_update_command' per
# https://github.com/Azure/azure-cli/blob/0b06b4f295766bcadaebdb7cf8fc05c7d6c9a5a8/doc/authoring_command_modules/authoring_commands.md#generic-update-commands
def project_connection_update(
    instance,
):
    """
    Update a connection for a project of an Azure Cognitive Services account.

    :param instance: the existing connection; its ``properties`` are carried over.
    :return: a ConnectionUpdateContent wrapping ``instance.properties``.
    """
    return ConnectionUpdateContent(properties=instance.properties)
